from model import LSTM_STGCN
from skimage.metrics import normalized_root_mse as compare_nrmse
import pandas as pd
from torch.autograd import Variable
from pyhht.emd import EMD
import os
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
import torch.nn as nn
from utils import *


def set_seed(seed):
    """
    Seed every random number generator this script relies on.
    :param seed: integer seed applied to Python, NumPy and PyTorch (CPU and CUDA)
    :return: None
    """
    import random  # local import: the file does not import `random` at module level

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only child processes pick this up; the running interpreter fixed its
    # hash seed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Reproducible cuDNN behaviour needs deterministic kernels AND the
    # auto-tuner switched off: with benchmark=True cuDNN may select a
    # different algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def normalize(data):
    """
    Min-max scale the data to the range [0, 1].
    :param data: original data (sequence or array of numbers; 1-D in this file)
    :return: normalised data with an extra axis inserted at position 1
             (a column of shape (n, 1) for 1-D input), and a dict holding the
             'max_val' and 'min_val' used for scaling
    """
    values = np.asarray(data)
    max_val = values.max()
    min_val = values.min()
    diff = max_val - min_val
    if diff == 0:
        # A constant series has no range; map it to zeros instead of
        # dividing by zero and filling the output with NaN/inf.
        norm_data = np.zeros(values.shape, dtype=float)
    else:
        norm_data = (values - min_val) / diff
    norm_data = np.expand_dims(norm_data, axis=1)
    norm_param = {'max_val': max_val, 'min_val': min_val}
    return norm_data, norm_param

def denormalize(norm_data, norm_param):
    """
    Inverse of the min-max normalisation.
    :param norm_data: normalised data (sequence or array of any shape)
    :param norm_param: dict with the 'max_val' and 'min_val' used for scaling
    :return: array with the same shape as norm_data, mapped back to the
             original value range
    """
    max_val = norm_param['max_val']
    min_val = norm_param['min_val']
    diff = max_val - min_val
    # One vectorised expression instead of a Python-level loop over elements.
    return np.asarray(norm_data) * diff + min_val

def load_data(df, len_train, len_val):
    """
    Cut a sliceable object into consecutive train / validation / test parts.
    :param df: any object supporting slice indexing
    :param len_train: number of leading items that form the training part
    :param len_val: number of items after the training part used for validation
    :return: tuple (train, val, test); test is everything after the validation part
    """
    val_end = len_train + len_val
    return df[:len_train], df[len_train:val_end], df[val_end:]

# Seed the random number generators before any data handling or model creation.
set_seed(512)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the exchange-rate table; the USDCNY column is the series being modelled.
df = pd.read_csv('data/exchange_rate.csv', encoding='gb2312')
print(df.columns)
label = df['USDCNY'].values
# Work in log10 space and keep only the first 731 observations.
label = np.log10(label)
label = label[0:731]
# Column view of the label series: shape (n, 1).
data = np.expand_dims(label, axis=1)
print(data.shape)

# NOTE(review): `data` was built from `label`, which is already log10-transformed,
# so this applies log10 a second time to the model inputs — confirm it is intended.
data = np.log10(data)
# Replace any NaN/inf produced by the logarithm with finite numbers.
data = np.nan_to_num(data)
# NOTE(review): `label` was already cut to 731 rows, so this slice looks redundant.
data = data[0:731]
# Decompose every column with EMD and collect the resulting components as
# additional feature columns.
data_EMD = np.empty((data.shape[0], 0))
for i in range(data.shape[1]):
    decomposer = EMD(data[:, i])
    # Transposed so the components can be concatenated column-wise —
    # presumably decompose() returns one component per row; TODO confirm.
    imfs = decomposer.decompose().T
    data_EMD = np.concatenate([data_EMD, imfs], axis=1)
# Feature matrix = the (log-transformed) series followed by its EMD components.
data = np.concatenate([data, data_EMD], axis=1)

# Data normalisation: min-max scale every feature column independently.
data_EMD_normalize = np.empty((data.shape[0], 0))
for i in range(data.shape[1]):
    norm_data, _ = normalize(data[:, i])
    data_EMD_normalize = np.concatenate([data_EMD_normalize, norm_data], axis=1)
# The label is scaled separately; its min/max are kept so predictions can be
# denormalised when the metrics are computed.
label_normalize_data, label_norm_param = normalize(label)

# Hyperparameters
n_his = 3  # input window length (passed to split_data as seq_length)
# NOTE(review): passed to LSTM_STGCN; presumably the number of graph nodes, which
# would have to match the number of feature columns (1 + number of EMD
# components) — confirm against the model and the EMD output.
n_route = 16
Ks, Kt = 1, 1  # Ks is also the order given to cheb_poly below
blocks = [[1, 32, 64], [64, 32, 128]]  # forwarded to LSTM_STGCN unchanged
drop_prob = 0.1
train_size = 512  # number of leading samples used for training
batch_size = 64
epochs = 200
lr = 0.0012
save_path = "save/model.pt"  # where the best model's state_dict is written


# Pairwise Canberra distance between feature columns (condensed vector form).
distA = pdist(data_EMD_normalize.T, metric='canberra') # canberra,euclidean
# Expand the condensed distance vector into a square, symmetric matrix
W = squareform(distA)
# scaled_laplacian / cheb_poly come from utils; presumably they build the
# graph-convolution kernel from W — see utils for the exact definition.
L = scaled_laplacian(W)
Lk = cheb_poly(L, Ks)
# float32 tensor on the selected device, handed to the model constructor.
Lk = torch.Tensor(Lk.astype(np.float32)).to(device)

# Build sliding-window samples and split them into train and test tensors
def split_data(data, label, seq_length, train_size, device):
    """
    Build sliding-window samples and split them into train and test tensors.
    :param data: feature matrix indexed by time step along its first axis
    :param label: 2-D target array; label[t][0] is the target at time step t
    :param seq_length: number of consecutive time steps in each input window
    :param train_size: number of leading samples assigned to the training set
    :param device: torch device the returned tensors are moved to
    :return: (trainX, trainY, testX, testY) as float32 tensors; the X tensors
             get a singleton axis inserted at position 1 and the Y tensors
             have shape (n, 1)
    """
    # Window i covers data[i : i + seq_length] and predicts label[i + seq_length],
    # so valid start indices run up to and including len(data) - seq_length - 1.
    # (The previous bound stopped one short and silently dropped the last sample.)
    n_samples = len(data) - seq_length
    x = [data[i:i + seq_length] for i in range(n_samples)]
    y = [label[i + seq_length][0] for i in range(n_samples)]

    def _features(windows):
        # Insert the singleton axis at position 1 and move to the target device.
        stacked = np.expand_dims(np.array(windows), axis=1)
        return torch.as_tensor(stacked, dtype=torch.float32).to(device)

    def _targets(values):
        # Column vector of targets on the target device.
        column = np.array(values).reshape(-1, 1)
        return torch.as_tensor(column, dtype=torch.float32).to(device)

    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4,
    # so plain tensors are returned.
    trainX = _features(x[:train_size])
    trainY = _targets(y[:train_size])

    testX = _features(x[train_size:])
    testY = _targets(y[train_size:])

    return trainX, trainY, testX, testY

# Windowed samples: the first `train_size` go to training, the rest to testing.
x_train, y_train, x_test, y_test = split_data(data_EMD_normalize, label_normalize_data, n_his, train_size, device)

train_data = torch.utils.data.TensorDataset(x_train, y_train)
train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

test_data = torch.utils.data.TensorDataset(x_test, y_test)
test_iter = torch.utils.data.DataLoader(test_data, batch_size)
# NOTE(review): validation reuses the test set, so the checkpoint chosen below
# is selected on the same data the final metrics are reported on.
val_data = test_data
val_iter = test_iter

loss = nn.MSELoss()
model = LSTM_STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob, x_train.shape[-1], device).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
# Multiply the learning rate by 0.7 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)

# torch.save does not create missing directories; make sure the target exists.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Model training: keep the weights with the lowest validation loss seen so far.
min_val_loss = np.inf
for epoch in range(1, epochs + 1):
    l_sum, n = 0.0, 0
    model.train()
    for x, y in train_iter:
        y_pred = model(x).view(len(x), -1)
        batch_loss = loss(y_pred, y)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Weight each batch by its size so l_sum / n is a per-sample average.
        l_sum += batch_loss.item() * y.shape[0]
        n += y.shape[0]
    scheduler.step()
    val_loss = evaluate_model(model, loss, val_iter)
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        print("save model")
        torch.save(model.state_dict(), save_path)
    print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)

# Reload the best checkpoint into a fresh model for evaluation.
best_model = LSTM_STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob, x_train.shape[-1], device).to(device)
# map_location keeps loading working when the checkpoint was written on a
# different device than the one currently selected.
best_model.load_state_dict(torch.load(save_path, map_location=device))

# Calculation of evaluation indicators
def evaluate_metric_v(model, data_iter, label_norm_param):
    """
    Compute error metrics on denormalised predictions.
    :param model: network to evaluate; it is switched to eval mode
    :param data_iter: iterable yielding (x, y) tensor batches
    :param label_norm_param: min/max dict passed to denormalize for both the
                             targets and the predictions
    :return: (MAE, MAPE, RMSE, y_true, y_pred); MAPE is a fraction (the caller
             multiplies by 100), y_true and y_pred are flat arrays of all
             denormalised targets and predictions
    """
    model.eval()
    # Each list starts with an empty float array so concatenation also works
    # when data_iter yields nothing.
    true_parts = [np.empty(0)]
    pred_parts = [np.empty(0)]
    abs_parts = [np.empty(0)]
    pct_parts = [np.empty(0)]
    sq_parts = [np.empty(0)]
    with torch.no_grad():
        for x, y in data_iter:
            y = denormalize(y.cpu().numpy(), label_norm_param)
            y_pred = denormalize(model(x).view(len(x), -1).cpu().numpy(), label_norm_param)
            # reshape(-1) rather than squeeze: squeezing a batch that holds a
            # single sample yields a 0-d array, which cannot be concatenated.
            true_parts.append(y.reshape(-1))
            pred_parts.append(y_pred.reshape(-1))

            d = np.abs(y - y_pred)
            abs_parts.append(d.reshape(-1))
            # Divide by |y| so the percentage error stays non-negative even
            # for negative targets.
            pct_parts.append((d / np.abs(y)).reshape(-1))
            sq_parts.append((d ** 2).reshape(-1))

    y_true_out = np.concatenate(true_parts, axis=0)
    y_pred_out = np.concatenate(pred_parts, axis=0)
    MAE = np.concatenate(abs_parts).mean()
    MAPE = np.concatenate(pct_parts).mean()
    RMSE = np.sqrt(np.concatenate(sq_parts).mean())

    return MAE, MAPE, RMSE, y_true_out, y_pred_out

# Score the reloaded best checkpoint on the test iterator and report the metrics.
l = evaluate_model(best_model, loss, test_iter)
MAE, MAPE, RMSE, y_true, y_pred = evaluate_metric_v(best_model, test_iter, label_norm_param)
# MAPE comes back as a fraction; it is reported as a percentage.
mape_percent = MAPE*100
print("test loss:", l, "\nMAE:", MAE, ",MAPE(%):", mape_percent, ",RMSE:", RMSE)
nrmse_percent = compare_nrmse(y_true, y_pred) * 100
print("NRMSE =", nrmse_percent)
print("RMSE =", RMSE)
print("MAPE(%)= ", mape_percent)